from binary_tree import TreeNode


class Solution:
    """Binary-tree solution wrapper exposing ``invertTree``."""

    def invertTree(self, root: "TreeNode | None") -> "TreeNode | None":
        """Mirror a binary tree in place by swapping every node's children.

        The tree is mutated in place and the same ``root`` object is
        returned, so callers may use either the argument or the result.

        Traversal uses an explicit stack instead of recursion so that very
        deep (e.g. fully skewed) trees do not exceed Python's recursion limit.

        Args:
            root: Root of the tree to invert, or ``None`` for an empty tree.

        Returns:
            ``root`` itself (``None`` when the tree is empty).
        """
        pending = [root]
        while pending:
            node = pending.pop()
            if not node:
                # Empty subtree (or empty tree): nothing to swap.
                continue
            node.left, node.right = node.right, node.left
            # Each node's swap is independent, so visit order does not matter.
            pending.append(node.left)
            pending.append(node.right)
        return root
